# Notebook setup: install experiment tracking (wandb), data/model stacks
# (datasets, transformers), plotting (plotly-express) and gdown for
# downloading the DINO weights from Google Drive.
!pip install -q wandb
!pip install -q datasets transformers
!pip install -q plotly-express
!pip install -U gdown -q
# Authenticate with the Hugging Face Hub (required by push_to_hub later).
from huggingface_hub import notebook_login
notebook_login()
# Start the Weights & Biases run that all subsequent logging attaches to.
import wandb
wandb.login()
wandb.init(project="vit-classification-eurosat", entity="polejowska")
from datasets import load_dataset, load_metric
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, TrainingArguments, Trainer
import torch
from torchvision.transforms import (
Compose,
Normalize,
Resize,
ToTensor,
)
import numpy as np
from tqdm import tqdm
from PIL import Image
import requests
import zipfile
from io import BytesIO
import gdown
import tensorflow as tf
from tensorflow import keras
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn styling for all matplotlib figures in this notebook.
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})
# Download the EuroSAT archive and load it as an HF "imagefolder" dataset;
# class labels are inferred from the folder structure inside the zip.
dataset = load_dataset("imagefolder", data_files="https://madm.dfki.de/files/sentinel/EuroSAT.zip")
print(f"Dataset structure: {dataset}\n")
print(f"Number of training examples: {len(dataset['train'])}\n")
print(f"Dataset sample (image, label): {dataset['train'][0]}\n")
print(f"Dataset features: {dataset['train'].features}\n")
print(f"Class labels: {dataset['train'].features['label'].names}\n")
# Build label <-> id lookup tables used by the model config and the plots.
labels = dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label
# Record dataset facts in the run config for reproducibility.
# NOTE(review): "num_train_examples" is overwritten later (after subsampling).
wandb.config.update({"class_labels": dataset["train"].features["label"].names})
wandb.config.update({"num_train_examples": len(dataset["train"])})
def plot_class_distribution(dataset, id2label, dataset_name="dataset"):
    """Show and return a plotly histogram of class counts for a split.

    ``dataset`` must expose a "label" column of integer ids; ``id2label``
    maps those ids to human-readable class names used on the x-axis.
    """
    class_names = [id2label[label_id] for label_id in dataset["label"]]
    figure = px.histogram(
        x=class_names,
        title=f"Distribution of classes in the {dataset_name}",
    )
    figure.update_layout(xaxis_title="Class", yaxis_title="Number of examples")
    figure.show()
    return figure
# Log the class balance of the full (pre-split) training data to W&B.
entire_dataset_fig = plot_class_distribution(dataset["train"], id2label)
wandb.log({"class distribution in the entire dataset": entire_dataset_fig})
def display_random_images(dataset, label2id, id2label):
    """Display four randomly drawn samples in a 2x2 grid and log the figure.

    Titles show the integer label and its class name; the matplotlib figure
    is logged to W&B under "random_images".  (``label2id`` is accepted for
    symmetry with other helpers but is not needed here.)
    """
    fig = plt.figure(figsize=(10, 10))
    for panel in range(4):
        sample_idx = np.random.randint(0, len(dataset))
        sample = dataset[sample_idx]
        img, lbl = sample["image"], sample["label"]
        axis = fig.add_subplot(2, 2, panel + 1)
        axis.imshow(img)
        axis.set_title(f"Class: {lbl} ({id2label[lbl]})")
        axis.axis("off")
    plt.show()
    wandb.log({"random_images": fig})
display_random_images(dataset["train"], label2id, id2label)
# Backbone checkpoint to fine-tune; the commented alternatives were also tried.
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
# model_checkpoint = "facebook/convnext-tiny-224"
# model_checkpoint = "google/vit-base-patch16-224-in21k"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# Swin's extractor exposes size as {"height": ..., "width": ...}; ViT-style
# checkpoints expose {"shortest_edge": ...} instead (see the commented line).
resize_value = (feature_extractor.size['height'], feature_extractor.size['width'])
# resize_value = (feature_extractor.size['shortest_edge'], feature_extractor.size['shortest_edge'])
print(f"Resize value: {resize_value}")
# Torchvision preprocessing matching the checkpoint's expected input:
# resize to the extractor's size, convert to a CHW tensor in [0, 1],
# then normalize with the checkpoint's channel statistics.
data_transforms = Compose(
    [
        Resize(resize_value),
        ToTensor(),
        Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ]
)
def add_pixel_values_feature(batch):
    """Add a "pixel_values" column of preprocessed tensors to a batch.

    Each image is forced to RGB (EuroSAT files may not all be 3-channel
    already) and run through the module-level ``data_transforms`` pipeline.
    Intended for use with ``Dataset.set_transform``.
    """
    batch["pixel_values"] = [data_transforms(img.convert("RGB")) for img in batch["image"]]
    return batch
# Hold out 20% of the data, then split that 20% again into validation (70%)
# and test (30%) subsets.
split = dataset["train"].train_test_split(test_size=0.2)
train_dataset = split["train"]
validation_dataset_split = split["test"].train_test_split(test_size=0.3)
validation_dataset = validation_dataset_split["train"]
test_dataset = validation_dataset_split["test"]
# Keep every 5th example to shrink each split 5x and speed up the run.
train_dataset = train_dataset.select(range(0, len(train_dataset), 5))
validation_dataset = validation_dataset.select(range(0, len(validation_dataset), 5))
test_dataset = test_dataset.select(range(0, len(test_dataset), 5))
print(f"Length of training dataset: {len(train_dataset)}")
print(f"Length of validation dataset: {len(validation_dataset)}")
print(f"Length of test dataset: {len(test_dataset)}")
# Log the class distribution of each split to W&B.
train_dataset_fig = plot_class_distribution(train_dataset, id2label, dataset_name="training dataset")
wandb.log({"class distribution in the training dataset": train_dataset_fig})
validation_dataset_fig = plot_class_distribution(validation_dataset, id2label, dataset_name="validation dataset")
wandb.log({"class distribution in the validation dataset": validation_dataset_fig})
test_dataset_fig = plot_class_distribution(test_dataset, id2label, dataset_name="test dataset")
wandb.log({"class distribution in the test dataset": test_dataset_fig})
# Record the final (subsampled) split sizes in the run config.
wandb.config.update({"num_train_examples": len(train_dataset)})
wandb.config.update({"num_validation_examples": len(validation_dataset)})
wandb.config.update({"num_test_examples": len(test_dataset)})
# Apply the torchvision pipeline lazily on access (adds "pixel_values").
train_dataset.set_transform(add_pixel_values_feature)
validation_dataset.set_transform(add_pixel_values_feature)
test_dataset.set_transform(add_pixel_values_feature)
def create_table(dataset):
    """Build a W&B table with one row (image, label id, class name) per sample.

    Uses the module-level ``id2label`` mapping for the readable class name.
    """
    table = wandb.Table(columns=["image", "label", "class name"])
    for idx in tqdm(range(len(dataset))):
        sample = dataset[idx]
        img, lbl = sample["image"], sample["label"]
        table.add_data(wandb.Image(img), lbl, id2label[lbl])
    return table
# Materialize and log browsable sample tables for each split.
train_table = create_table(train_dataset)
validation_table = create_table(validation_dataset)
test_table = create_table(test_dataset)
wandb.log({"train_dataset": train_table})
wandb.log({"validation_dataset": validation_table})
wandb.log({"test_dataset": test_table})
# Load the pretrained backbone with a fresh classification head sized to the
# EuroSAT classes; ignore_mismatched_sizes lets the old head be discarded.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
# Training hyperparameters.
MODEL_NAME = model_checkpoint.split("/")[-1]
NUM_TRAIN_EPOCHS = 10
LEARNING_RATE = 5e-5
BATCH_SIZE = 32
STRATEGY = "epoch"  # evaluate and checkpoint once per epoch
# Human-readable run name shown in the W&B UI.
wandb.run.name = f"{MODEL_NAME} (epochs: {NUM_TRAIN_EPOCHS})"
# Trainer configuration. Effective train batch size is 32 x 4 = 128 via
# gradient accumulation. remove_unused_columns=False keeps the raw "image"
# column so set_transform can compute pixel_values on the fly.
# NOTE(review): evaluation_strategy is renamed eval_strategy in newer
# transformers releases — confirm against the pinned version.
args = TrainingArguments(
    f"{MODEL_NAME}-eurosat",
    remove_unused_columns=False,
    evaluation_strategy=STRATEGY,
    save_strategy=STRATEGY,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="wandb",
    push_to_hub=True,
)
def collate_fn(batches):
    """Collate per-sample dicts into a model-ready batch.

    Stacks each sample's "pixel_values" tensor along a new batch dimension
    and collects the integer labels into a 1-D tensor under "labels" (the
    key the Trainer/model expects for the loss).
    """
    stacked_pixels = torch.stack([sample["pixel_values"] for sample in batches])
    targets = torch.tensor([sample["label"] for sample in batches])
    return {"pixel_values": stacked_pixels, "labels": targets}
# NOTE(review): datasets.load_metric is deprecated in favor of the
# `evaluate` library — confirm against the pinned datasets version.
accuracy_metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    """Compute accuracy from a Trainer EvalPrediction (logits + label ids)."""
    predicted_ids = eval_pred.predictions.argmax(axis=1)
    return accuracy_metric.compute(predictions=predicted_ids, references=eval_pred.label_ids)
# Assemble the Trainer; passing the feature extractor as "tokenizer" makes
# it get saved/pushed alongside the model.
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=feature_extractor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
# Fine-tune, then persist the model, metrics and trainer state locally.
trainer_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", trainer_results.metrics)
trainer.save_metrics("train", trainer_results.metrics)
trainer.save_state()
# Final evaluation on the validation split (best checkpoint is loaded
# because load_best_model_at_end=True).
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Upload model, extractor and metrics to the Hugging Face Hub.
trainer.push_to_hub()
def create_table_with_predictions(dataset, predictions):
    """Build a W&B table pairing each sample with its predicted class.

    ``predictions`` is a sequence of predicted label ids aligned with the
    dataset order; class names come from the module-level ``id2label``.
    """
    table = wandb.Table(columns=["image", "label", "class name", "prediction", "prediction class name"])
    for idx in tqdm(range(len(dataset))):
        sample = dataset[idx]
        img, lbl = sample["image"], sample["label"]
        pred = predictions[idx]
        table.add_data(wandb.Image(img), lbl, id2label[lbl], pred, id2label[pred])
    return table
# Predict on the held-out test split and take the argmax class per sample.
test_predictions = np.argmax(trainer.predict(test_dataset).predictions, axis=1)
test_table_with_predictions = create_table_with_predictions(test_dataset, test_predictions)
wandb.log({"test_table_with_predictions": test_table_with_predictions})
# Also persist the predictions table as a versioned W&B artifact.
test_table_with_predictions_artifact = wandb.Artifact(
    name="test_table_with_predictions",
    type="test_table_with_predictions",
    description="A table with predictions on the test dataset",
    metadata={
        "num_test_examples": len(test_dataset),
    },
)
test_table_with_predictions_artifact.add(test_table_with_predictions, "test_table_with_predictions")
wandb.log_artifact(test_table_with_predictions_artifact)
# Log W&B's built-in confusion matrix for the test predictions.
# Named to avoid colliding with sklearn.metrics.confusion_matrix, which is
# imported just below and would otherwise shadow this variable.
wandb_confusion_matrix = wandb.plot.confusion_matrix(
    probs=None,
    y_true=test_dataset[:]["label"],
    preds=test_predictions,
    class_names=list(id2label.values()),
)
wandb.log({"confusion_matrix": wandb_confusion_matrix})
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(cm, class_names):
    """Show and return a plotly heatmap of a confusion matrix.

    ``cm`` is a square array indexed [true, predicted]; each cell is
    annotated with its value rounded to two decimals.
    """
    num_classes = len(class_names)
    fig = go.Figure(data=go.Heatmap(z=cm, x=class_names, y=class_names))
    # One annotation per cell, in row-major (flattened) order: x cycles
    # over predicted classes, y advances over true classes.
    cell_labels = [
        go.layout.Annotation(
            text=str(round(value, 2)),
            x=class_names[idx % num_classes],
            y=class_names[idx // num_classes],
            font_size=14,
            showarrow=False,
        )
        for idx, value in enumerate(cm.flatten())
    ]
    fig.update_layout(
        title="Confusion Matrix",
        xaxis_title="Predicted label",
        yaxis_title="True label",
        annotations=cell_labels,
    )
    fig.show()
    return fig
# Row-normalize the confusion matrix so each row shows per-class recall.
# NOTE(review): a class absent from the test labels would give a zero row
# sum and NaNs here — fine for EuroSAT-sized splits, but worth confirming.
cm = confusion_matrix(test_dataset[:]["label"], test_predictions)
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
cm_plot = plot_confusion_matrix(cm, list(id2label.values()))
wandb.log({"confusion_matrix (plotly)": cm_plot})
# Out-of-distribution sanity check: download a few Google Earth screenshots
# (hosted in the model repo) and classify them with the fine-tuned model.
from PIL import Image
import requests
imgs_urls = [
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110307.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110454.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110708.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20114336.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20113158.png",
]
imgs = []
for img_url in imgs_urls:
    # Stream the response and decode it as an RGB PIL image.
    image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    imgs.append(image)
def predict(feature_extractor, model, images, device):
    """Classify PIL images with the fine-tuned model, showing each result.

    Each image is preprocessed by the HF feature extractor, moved to
    ``device``, and classified with gradients disabled; the image is shown
    with its predicted class name as the title.

    Returns the list of predicted class ids, in input order.
    """
    # The original version also built an unused torchvision Compose pipeline
    # here; the feature extractor already performs resize + normalization,
    # so that dead code has been removed.
    encodings = [
        feature_extractor(image.convert("RGB"), return_tensors="pt").to(device)
        for image in images
    ]
    predicted_class_idxs = []
    for i, encoding in enumerate(encodings):
        with torch.no_grad():
            outputs = model(**encoding)
        logits = outputs.logits
        predicted_class_idxs.append(logits.argmax(-1).item())
        predicted_class_name = model.config.id2label[predicted_class_idxs[i]]
        plt.imshow(images[i])
        plt.title(f"{predicted_class_name}")
        plt.axis("off")
        plt.show()
    return predicted_class_idxs
# Run inference on the GPU when available, otherwise fall back to CPU so
# this cell also works on machines without CUDA (the Trainer leaves the
# model on whichever device it trained on, which matches this choice).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
predicted_class_idxs = predict(feature_extractor, model, imgs, device)
# Log the screenshots with their predicted id + class name as captions.
wandb.log({"inference_images":
    [wandb.Image(
        image,
        caption=f"Predicted class: {predicted_class_idxs[i]} {model.config.id2label[predicted_class_idxs[i]]}"
    ) for i, image in enumerate(imgs)]})
# The DINO model is used for visualizing attention maps. Attention maps are
# overlaid on the input images.
def load_image_from_url(url):
    """Download an image over HTTP and return it as an RGB PIL image."""
    payload = requests.get(url).content
    return Image.open(BytesIO(payload)).convert('RGB')
# Fetch one sample image plus the full screenshot set again for the
# attention-map visualizations.
img_url = 'https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-finetuned-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110307.png'
image = load_image_from_url(img_url)
imgs_urls = [
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110307.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110454.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20110708.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20114336.png",
    "https://huggingface.co/polejowska/swin-tiny-patch4-window7-224-eurosat/resolve/main/GoogleEarth/Zrzut%20ekranu%202022-12-10%20113158.png",
]
imgs = [load_image_from_url(img_url) for img_url in imgs_urls]
# Preview all screenshots side by side.
fig, axs = plt.subplots(1, len(imgs), figsize=(20, 20))
for i, img in enumerate(imgs):
    axs[i].imshow(img)
    axs[i].axis('off')
plt.show()
def preprocess_image(image, size=224):
    """Preprocess a PIL image for the DINO ViT: resize, center-crop, normalize.

    Follows the DeiT/DINO convention of resizing to (256/224) * size with
    bicubic interpolation before center-cropping to (size, size), then
    normalizing with ImageNet channel statistics (expressed in 0-255 scale).

    Returns a numpy array with a leading batch dimension of 1.
    """
    # The crop now honours the `size` parameter (it was hard-coded to 224),
    # and the unused Rescaling layer has been removed — normalization alone
    # matches what the rest of the notebook expects.
    crop_layer = keras.layers.CenterCrop(size, size)
    norm_layer = keras.layers.Normalization(
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
    )
    image = np.array(image)
    image = tf.expand_dims(image, 0)
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")
    image = crop_layer(image)
    image = norm_layer(image)
    return image.numpy()
# Preprocess the single sample and the full screenshot set.
preprocessed_image = preprocess_image(image)
preprocessed_images = [preprocess_image(img) for img in imgs]
# Preview the preprocessed tensors.
# NOTE(review): values are ImageNet-normalized (not in [0, 1]), so imshow
# will clip them and the preview colors will look off — confirm intended.
plt.figure(figsize=(10, 10))
for i, image in enumerate(preprocessed_images):
    ax = plt.subplot(1, len(preprocessed_images), i + 1)
    plt.imshow(image[0])
    plt.axis("off")
def get_gdrive_model(model_id: str) -> tf.keras.Model:
    """Download a zipped Keras model from Google Drive and wrap it.

    The archive is fetched with gdown, extracted into the working directory,
    and loaded without compiling. The returned model takes a (224, 224, 3)
    input and outputs the loaded model's (outputs, attention_weights) pair.
    """
    archive_path = gdown.download(id=model_id, quiet=False)
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall()
    # Strip the extension to get the extracted SavedModel directory name.
    saved_dir = archive_path.split(".")[0]
    image_input = keras.Input((224, 224, 3))
    base_model = keras.models.load_model(saved_dir, compile=False)
    logits, attention_weights = base_model(image_input)
    return keras.Model(image_input, outputs=[logits, attention_weights])
def get_model(id):
    """Thin wrapper around get_gdrive_model (``id`` kept for compatibility,
    though it shadows the builtin)."""
    return get_gdrive_model(id)
# Load the DINO ViT-Base/16 weights (Google Drive file id) and collect the
# attention score dict for every preprocessed screenshot.
vit_dino_base16 = get_model("16_1oDm0PeCGJ_KGBG5UKVN7TsAtiRNrN")
predictions, attention_score_dict = vit_dino_base16.predict(preprocessed_image)
predictions_attention_score_dict = {}
for i, image in enumerate(preprocessed_images):
    predictions, attention_score_dict = vit_dino_base16.predict(image)
    predictions_attention_score_dict[i] = attention_score_dict
def attention_heatmap(attention_score_dict, image, patch_size=16):
    """Build per-head CLS-attention maps upsampled to image resolution.

    Takes the attention scores of the deepest transformer block, extracts
    each head's attention from the CLS token to all patch tokens, reshapes
    it into the patch grid, and bicubic-free resizes it back to pixel space.
    Returns a tensor of shape (H, W, num_heads).
    """
    num_tokens = 1   # skip the CLS token itself when slicing patch tokens
    num_heads = 12   # ViT-Base has 12 attention heads
    # Sort layer keys by the numeric index embedded in the name, deepest
    # first, so attention_score_list[0] is the last transformer block.
    attention_score_list = list(attention_score_dict.keys())
    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)
    # Patch-grid dimensions derived from the (batched) image shape.
    w_featmap = image.shape[2] // patch_size
    h_featmap = image.shape[1] // patch_size
    attention_scores = attention_score_dict[attention_score_list[0]]
    # CLS-token row of each head's attention matrix, patch tokens only.
    # NOTE(review): assumes scores are (batch, heads, tokens, tokens) —
    # confirm against the DINO model's output layout.
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    # Heads-last so tf.image.resize treats heads as channels.
    attentions = attentions.transpose((1, 2, 0))
    attentions = tf.image.resize(
        attentions, size=(h_featmap * patch_size, w_featmap * patch_size)
    )
    return attentions
def denormalize_image(image):
    """Invert the ImageNet normalization and rescale pixels into [0, 1]."""
    imagenet_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
    imagenet_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
    restored = (image * imagenet_std + imagenet_mean) / 255.0
    return tf.clip_by_value(restored, 0.0, 1.0)
def plot_attention_heatmap(attention_score_dict, preprocessed_img_orig):
    """Overlay each attention head's map on the denormalized input image.

    Draws up to 12 head overlays in a 3x4 grid, logs both a rasterized copy
    and the live figure to W&B, and returns the matplotlib figure.
    """
    image = denormalize_image(preprocessed_img_orig)
    attentions = attention_heatmap(attention_score_dict, image)
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
    img_count = 0
    for i in range(3):
        for j in range(4):
            # Guard on the head count (last axis). The original compared
            # against len(attentions), which is the image height, not the
            # number of heads.
            if img_count < attentions.shape[-1]:
                axes[i, j].imshow(image[0])
                axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
                axes[i, j].title.set_text(f"Attention head: {img_count}")
                axes[i, j].axis("off")
                img_count += 1
    fig.canvas.draw()
    # np.fromstring on binary data was deprecated and removed in NumPy 2.0;
    # np.frombuffer is the supported zero-copy replacement.
    # NOTE(review): canvas.tostring_rgb() is removed in Matplotlib >= 3.10 —
    # switch to buffer_rgba() if the notebook's matplotlib is upgraded.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    wandb.log({"attention_heatmap": [wandb.Image(data)]})
    wandb.log({"attention_heads attention maps": fig})
    return fig
# Visualize the attention maps for several screenshots.
plot_attention_heatmap(predictions_attention_score_dict[0], preprocessed_images[0])
plot_attention_heatmap(predictions_attention_score_dict[3], preprocessed_images[3])
plot_attention_heatmap(predictions_attention_score_dict[2], preprocessed_images[2])
plot_attention_heatmap(predictions_attention_score_dict[1], preprocessed_images[1])
# Collect the raw test images and run the DINO attention pipeline on them.
# NOTE(review): set_transform was applied to test_dataset earlier — confirm
# that indexing still yields an "image" entry suitable for preprocessing.
test_dataset_maps = [img['image'] for img in test_dataset]
test_dataset_maps  # notebook cell output: preview the list
processed_images_testdataset = [preprocess_image(img) for img in test_dataset_maps]
for i, image in enumerate(processed_images_testdataset):
    predictions, attention_score_dict = vit_dino_base16.predict(image)
    plot_attention_heatmap(attention_score_dict, image)
# Final sanity check: load the pushed model from the Hub via the pipeline
# API and classify one of the screenshots, then close the W&B run.
from transformers import pipeline
repo_name = "polejowska/vit-base-patch16-224-in21k-eurosat"
pipe = pipeline("image-classification", repo_name)
pipe(imgs[0])
wandb.finish()